/*
 * THIS IS A GENERATED FILE!  DO NOT CHANGE THIS FILE!  CHANGE THE
 * CORRESPONDING TEMPLATE FILE, PLEASE!
 */

#include <gtest/gtest.h>
#include <shogun/base/some.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/machine/Machine.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/io/SerializableHdf5File.h>
#include <shogun/io/CSVFile.h>
#include <shogun/io/SGIO.h>
#include <shogun/machine/gp/ExactInferenceMethod.h>
#include <shogun/machine/gp/GaussianLikelihood.h>
#include <shogun/machine/gp/ProbitLikelihood.h>
#include <shogun/machine/gp/SingleLaplaceInferenceMethod.h>
#include <shogun/machine/gp/ZeroMean.h>
#include "environments/LinearTestEnvironment.h"
#include "environments/RegressionTestEnvironment.h"
#include "environments/MultiLabelTestEnvironment.h"
#include "utils/Utils.h"

using namespace shogun;

/* Global test environments, defined in another translation unit. The
 * fixture's load_data() reads its training/test data from them. */
extern LinearTestEnvironment* linear_test_env;
extern MultiLabelTestEnvironment* multilabel_test_env;
extern RegressionTestEnvironment* regression_test_env;

/**
 * Fixture shared by all trained-model serialization tests.
 *
 * load_data() picks training features, test features and training labels
 * from the global test environment that matches a problem type and takes a
 * reference on each of them; TearDown() drops those references again.
 */
class TrainedModelSerializationTest : public ::testing::Test
{
protected:
	virtual void SetUp() {}

	/**
	 * Releases the references taken by load_data().
	 *
	 * The members start out as nullptr, so this is also well-defined when
	 * load_data() was never called or left before taking its references.
	 */
	virtual void TearDown()
	{
		SG_UNREF(train_feats)
		SG_UNREF(test_feats)
		SG_UNREF(train_labels)
	}

	/**
	 * Fills train_feats, test_feats and train_labels for the given problem
	 * type and takes one reference on each object.
	 *
	 * @param pt problem type; PT_BINARY and PT_CLASS use the linear test
	 *        environment, PT_MULTICLASS the multilabel one, PT_REGRESSION
	 *        the regression one. Any other value is reported via SG_SERROR
	 *        and FAIL().
	 */
	void load_data(EProblemType pt)
	{
		switch (pt)
		{
			case PT_BINARY:
			case PT_CLASS:
			{
				std::shared_ptr<GaussianCheckerboard> mock_data =
					linear_test_env->getBinaryLabelData();
				train_feats = mock_data->get_features_train();
				test_feats = mock_data->get_features_test();
				train_labels = mock_data->get_labels_train();
				break;
			}

			case PT_MULTICLASS:
			{
				std::shared_ptr<GaussianCheckerboard> mock_data =
					multilabel_test_env->getMulticlassFixture();
				train_feats = mock_data->get_features_train();
				test_feats = mock_data->get_features_test();
				train_labels = mock_data->get_labels_train();
				break;
			}

			case PT_REGRESSION:
				train_feats = regression_test_env->get_features_train();
				test_feats = regression_test_env->get_features_test();
				train_labels = regression_test_env->get_labels_train();
				break;

			default:
				SG_SERROR("Unsupported problem type: %d\n", pt);
				// FAIL() returns from this function, so the SG_REF calls
				// below are skipped and the members keep their old values.
				FAIL();
		}

		SG_REF(train_feats)
		SG_REF(test_feats)
		SG_REF(train_labels)
	}

	// Initialised to nullptr so that TearDown() never unrefs an
	// indeterminate pointer.
	CDenseFeatures<float64_t>* train_feats = nullptr;
	CDenseFeatures<float64_t>* test_feats = nullptr;
	CLabels* train_labels = nullptr;
};

/**
 * Saves a machine into a newly named temporary HDF5 file.
 *
 * @param machine machine to save; its store-model-features flag is set to
 *        store_model_features before saving
 * @param filename output parameter, overwritten with the name of the file
 *        that was written (template
 *        "shogun-unittest-trained-model-serialization-<name>.XXXXXX" passed
 *        through generate_temp_filename)
 * @param store_model_features value forwarded to set_store_model_features()
 * @return the result of machine->save_serializable()
 */
bool serialize_machine(CMachine* machine, std::string &filename, bool store_model_features=false)
{
	const std::string class_name = machine->get_name();
	filename = "shogun-unittest-trained-model-serialization-" + class_name + ".XXXXXX";
	// generate_temp_filename() edits the buffer in place; hand it the
	// string's own writable storage instead of casting away the const of
	// c_str().
	generate_temp_filename(&filename[0]);

	CSerializableHdf5File *file=new CSerializableHdf5File(filename.c_str(), 'w');
	// Hold our own reference for the lifetime of this function.
	SG_REF(file)
	machine->set_store_model_features(store_model_features);
	bool save_success=machine->save_serializable(file);
	file->close();
	// Release through reference counting (not SG_FREE) so that the object
	// allocated with new is destroyed properly.
	SG_UNREF(file)

	return save_success;
}

/**
 * Loads a machine from an HDF5 file and then removes that file.
 *
 * The file is unlinked whether or not loading succeeded.
 *
 * @param machine machine instance to load the stored state into
 * @param filename path of the file to read and delete
 * @return true only if both load_serializable() and unlink() succeeded
 */
bool deserialize_machine(CMachine *machine, std::string filename)
{
	CSerializableHdf5File *file=new CSerializableHdf5File(filename.c_str(), 'r');
	// Hold our own reference for the lifetime of this function.
	SG_REF(file)
	bool load_success=machine->load_serializable(file);

	file->close();
	// Release through reference counting (not SG_FREE) so that the object
	// allocated with new is destroyed properly.
	SG_UNREF(file)
	int delete_success=unlink(filename.c_str());

	return load_success && (delete_success == 0);
}

/* Tolerance handed to CLabels::equals() when the predictions of the original
 * machine are compared with those of the deserialized machine. */
const float64_t accuracy=1e-7;

{% macro machine_test(class) -%}
/*
 * Round-trip test for {{class}}: train on the fixture data, predict on the
 * test features, save the machine, load it into a fresh instance and check
 * that the fresh instance produces the same predictions (within accuracy).
 */
TEST_F(TrainedModelSerializationTest, {{class}})
{
	auto machine=some<{{class}}>();
	load_data(machine->get_machine_problem_type());

	machine->set_features(train_feats);
	machine->set_labels(train_labels);

	bool train_success=machine->train();
	ASSERT_TRUE(train_success);

	/* Disabled: detaching features and labels before saving, which was
	 * meant to avoid serialization of the data. They stay attached. */
//	machine->set_features(NULL);
//	machine->set_labels(NULL);

	auto predictions=wrap<CLabels>(machine->apply(test_feats));

	std::string filename;
	ASSERT_TRUE(serialize_machine(machine, filename));

	auto deserialized_machine=some<{{class}}>();
	ASSERT_TRUE(deserialize_machine(deserialized_machine, filename));

	auto deserialized_predictions=wrap<CLabels>(deserialized_machine->apply(test_feats));
	/* gtest assertion, like every other check in this test, so a mismatch
	 * is reported as a regular test failure. */
	ASSERT_TRUE(predictions->equals(deserialized_predictions, accuracy, true));
}
{%- endmacro %}

{% macro kernel_machine_test(class) -%}
{% for store_model_features in ["true", "false"] -%}
{% if store_model_features == "true" -%}
{% set test_name = class + "_store_model_features" -%}
{% else -%}
{% set test_name = class -%}
{% endif -%}
/*
 * Round-trip test for {{class}} with a Gaussian kernel, saved with
 * store_model_features={{store_model_features}}: train, predict on the test
 * features, save the machine, load it into a fresh instance and check that
 * the fresh instance produces the same predictions (within accuracy).
 */
TEST_F(TrainedModelSerializationTest, {{test_name}})
{
	auto machine=some<{{class}}>();
	load_data(machine->get_machine_problem_type());

	CGaussianKernel *kernel=new CGaussianKernel(2.0);
	kernel->init(train_feats, train_feats);
	machine->set_kernel(kernel);
	machine->set_labels(train_labels);

	bool train_success=machine->train();
	ASSERT_TRUE(train_success);

	auto predictions=Some<CLabels>(machine->apply(test_feats));

	std::string filename;
	ASSERT_TRUE(serialize_machine(machine, filename, {{store_model_features}}));

	auto deserialized_machine=some<{{class}}>();
	ASSERT_TRUE(deserialize_machine(deserialized_machine, filename));

	auto deserialized_predictions=Some<CLabels>(deserialized_machine->apply(test_feats));
	/* gtest assertion, like every other check in this test, so a mismatch
	 * is reported as a regular test failure. */
	ASSERT_TRUE(predictions->equals(deserialized_predictions, accuracy, true));
}
{% endfor %}
{%- endmacro %}
{# Dispatch table: base-class name -> macro that renders the test(s) for a
   concrete class deriving from that base. #}
{%
set macros = {
	'CLinearMachine': machine_test,
	'CNativeMulticlassMachine': machine_test,
	'CLinearMulticlassMachine': machine_test,
	'CKernelMachine': kernel_machine_test,
	'CKernelMulticlassMachine': kernel_machine_test}
%}
{# `machines` is not defined in this template; presumably it is supplied by
   the rendering context. It is iterated as {base class: {class name: attrs}}.
   For every class, emit an #include of attrs['include'] followed by the
   output of the macro registered for its base class. #}
{% for b, m in machines.items() -%}
{% for name, attrs in m.items() -%}
#include <{{attrs['include']}}>
{{ macros[b](name) }}
{% endfor %}
{% endfor %}